{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "39329df3-1f99-4b11-9405-5969d52368a7",
   "metadata": {},
   "source": [
    "# Decision Tree & Successive Halving + Random Search Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7f61a90e-a119-4bd0-af21-38604c5b4eec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scikit-learn: 1.0\n",
      "mlxtend     : 0.19.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -p scikit-learn,mlxtend"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f0489c2-dd9c-4e71-a78c-e01201762b37",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "271b17ff-5ea4-4161-8b7f-20ba8131d666",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train/Valid/Test sizes: 398 80 171\n"
     ]
    }
   ],
   "source": [
    "from sklearn import model_selection\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import datasets\n",
    "\n",
    "\n",
    "# scikit-learn's built-in breast cancer dataset: feature matrix X, labels y\n",
    "data = datasets.load_breast_cancer()\n",
    "X, y = data.data, data.target\n",
    "\n",
    "# Hold out 30% of the data as the test set; stratify to keep class proportions\n",
    "X_train, X_test, y_train, y_test = \\\n",
    "    train_test_split(X, y, test_size=0.3, random_state=1, stratify=y)\n",
    "\n",
    "# Carve a 20% validation set out of the training portion\n",
    "X_train_sub, X_valid, y_train_sub, y_valid = \\\n",
    "    train_test_split(X_train, y_train, test_size=0.2, random_state=1, stratify=y_train)\n",
    "\n",
    "# Report the three disjoint partitions. Use y_train_sub rather than y_train:\n",
    "# y_train still contains the validation samples, so the sizes would not add up.\n",
    "print('Train/Valid/Test sizes:', y_train_sub.shape[0], y_valid.shape[0], y_test.shape[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c922b01-86f0-4e83-9e36-446f99f6fe1b",
   "metadata": {},
   "source": [
    "## Successive Halving + Random Search"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72e56f33-ec33-46dd-afa2-a1b3c8b3da0b",
   "metadata": {},
   "source": [
    "\n",
    "- More info: \n",
    "  - https://scikit-learn.org/stable/modules/grid_search.html#successive-halving-user-guide\n",
    "  - https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.HalvingRandomSearchCV.html#sklearn.model_selection.HalvingRandomSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "96f0b4c1-803a-436f-93d5-31baab55faa5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8882539682539681"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import scipy.stats\n",
    "\n",
    "# Importing enable_halving_search_cv switches on the experimental halving\n",
    "# estimators; it has to run before HalvingRandomSearchCV is imported.\n",
    "from sklearn.experimental import enable_halving_search_cv\n",
    "from sklearn.model_selection import HalvingRandomSearchCV\n",
    "\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "\n",
    "# Base estimator whose hyperparameters are tuned\n",
    "clf = DecisionTreeClassifier(random_state=123)\n",
    "\n",
    "# Search space: scipy distributions are sampled, lists are drawn from uniformly\n",
    "params = dict(\n",
    "    min_samples_split=scipy.stats.randint(low=2, high=12),\n",
    "    min_impurity_decrease=scipy.stats.uniform(loc=0.0, scale=0.5),\n",
    "    max_depth=[6, 16, None],\n",
    ")\n",
    "\n",
    "# Successive-halving settings: the resource grown between iterations is the\n",
    "# number of training samples, and factor=3 controls how aggressively\n",
    "# candidates are eliminated from one iteration to the next.\n",
    "search_options = dict(\n",
    "    n_candidates='exhaust',\n",
    "    resource='n_samples',\n",
    "    factor=3,\n",
    "    random_state=123,\n",
    "    n_jobs=1,\n",
    ")\n",
    "\n",
    "search = HalvingRandomSearchCV(\n",
    "    estimator=clf,\n",
    "    param_distributions=params,\n",
    "    **search_options,\n",
    ")\n",
    "search.fit(X_train, y_train)\n",
    "\n",
    "# Mean cross-validated score of the best candidate\n",
    "search.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2c26399d-ebfc-4b06-86d9-36e49711e908",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'max_depth': None,\n",
       " 'min_impurity_decrease': 0.029838948304784174,\n",
       " 'min_samples_split': 2}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyperparameter setting of the best candidate found by the search\n",
    "search.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "763e816b-6437-45a9-812f-8b429472d75e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Accuracy: 0.95\n",
      "Test Accuracy: 0.94\n"
     ]
    }
   ],
   "source": [
    "# Score the refitted best model on the training data and the held-out test data\n",
    "best_model = search.best_estimator_\n",
    "train_acc = best_model.score(X_train, y_train)\n",
    "test_acc = best_model.score(X_test, y_test)\n",
    "\n",
    "print(f\"Training Accuracy: {train_acc:0.2f}\")\n",
    "print(f\"Test Accuracy: {test_acc:0.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
